#
# Inception-ResNet-V1 network components
# Details are in https://arxiv.org/pdf/1602.07261v2.pdf
#

# Layer factory: a convolution followed by batch normalization, with no
# activation function applied afterwards.
#   outChannels - first argument passed to ConvolutionalLayer
#   kernel      - second argument passed to ConvolutionalLayer, e.g. (3:3)
#   stride, pad - forwarded unchanged to ConvolutionalLayer
#   bnTimeConst - forwarded as normalizationTimeConstant of the batch normalization
# The convolution is created with 'heNormal' initialization and without a bias
# term (presumably because the following batch normalization supplies its own
# shift -- confirm).
ConvBNLayer {outChannels, kernel, stride, pad, bnTimeConst} = Sequential(
    ConvolutionalLayer{outChannels, kernel, init = 'heNormal', stride = stride, pad = pad, bias = false} :
    BatchNormalizationLayer{spatialRank = 2, normalizationTimeConstant = bnTimeConst}
)

# Layer factory: ConvBNLayer (convolution + batch normalization) followed by a
# ReLU activation. All five parameters are forwarded unchanged to ConvBNLayer.
ConvBNReLULayer {outChannels, kernel, stride, pad, bnTimeConst} = Sequential(
    ConvBNLayer{outChannels, kernel, stride, pad, bnTimeConst} :
    ReLU
)

#
# Figure 10 from https://arxiv.org/pdf/1602.07261v2.pdf
#
# Residual Inception block with three parallel branches that all read the same
# input x. The branch outputs are spliced together, projected by a 1x1
# convolution, added back to x, and passed through ReLU.
#   bnTimeConst - batch normalization time constant handed to every ConvBN* layer
# Every convolution here uses stride (1:1) with pad = true.
# NOTE(review): Plus(residual, x) presumably requires x to already have 256
# channels, matching the projection below -- confirm against the caller.
InceptionResNetA {bnTimeConst} = {
    apply(x) = {
        # Branch 1: a single 1x1 convolution with 32 output channels
        branch1x1 = ConvBNReLULayer{32, (1:1), (1:1), true, bnTimeConst}(x)

        # Branch 2: 1x1 convolution followed by a 3x3 convolution, 32 channels each
        branch3x3 = Sequential( 
            ConvBNReLULayer{32, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{32, (3:3), (1:1), true, bnTimeConst}
        ) (x)

        # Branch 3: 1x1 convolution followed by two 3x3 convolutions, 32 channels each
        branch3x3dbl = Sequential(
            ConvBNReLULayer{32, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{32, (3:3), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{32, (3:3), (1:1), true, bnTimeConst}
        ) (x)

        # Join the three branches along axis 3 (presumably the channel axis -- confirm),
        # then project to 256 channels with a 1x1 ConvBNLayer, i.e. without ReLU
        concat = Splice((branch1x1:branch3x3:branch3x3dbl), axis=3)
        residual = ConvBNLayer{256, (1:1), (1:1), true, bnTimeConst}(concat)

        # Residual connection: add the block input, then apply the activation
        sum = Plus(residual, x)
        out = ReLU(sum)
    }.out
}.apply

#
# Figure 11 from https://arxiv.org/pdf/1602.07261v2.pdf
#
# Residual Inception block with two parallel branches that both read the same
# input x. The second branch factorizes a 7x7 convolution into a (1:7) kernel
# followed by a (7:1) kernel. The branch outputs are spliced together,
# projected by a 1x1 convolution, added back to x, and passed through ReLU.
#   bnTimeConst - batch normalization time constant handed to every ConvBN* layer
# Every convolution here uses stride (1:1) with pad = true.
# NOTE(review): Plus(residual, x) presumably requires x to already have 896
# channels, matching the projection below -- confirm against the caller.
InceptionResNetB {bnTimeConst} = {
    apply(x) = {
        # Branch 1: a single 1x1 convolution with 128 output channels
        branch1x1 = ConvBNReLULayer{128, (1:1), (1:1), true, bnTimeConst}(x)

        # Branch 2: 1x1 convolution, then (1:7) and (7:1) convolutions, 128 channels each
        branch7x7 = Sequential( 
            ConvBNReLULayer{128, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{128, (1:7), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{128, (7:1), (1:1), true, bnTimeConst}
        ) (x)

        # Join both branches along axis 3 (presumably the channel axis -- confirm),
        # then project to 896 channels with a 1x1 ConvBNLayer, i.e. without ReLU
        concat = Splice((branch1x1:branch7x7), axis=3)
        residual = ConvBNLayer{896, (1:1), (1:1), true, bnTimeConst}(concat)

        # Residual connection: add the block input, then apply the activation
        sum = Plus(residual, x)
        out = ReLU(sum)
    }.out
}.apply

#
# Figure 13 from https://arxiv.org/pdf/1602.07261v2.pdf
#
# Residual Inception block with two parallel branches that both read the same
# input x. The second branch factorizes a 3x3 convolution into a (1:3) kernel
# followed by a (3:1) kernel. The branch outputs are spliced together,
# projected by a 1x1 convolution, added back to x, and passed through ReLU.
#   bnTimeConst - batch normalization time constant handed to every ConvBN* layer
# Every convolution here uses stride (1:1) with pad = true.
# NOTE(review): Plus(residual, x) presumably requires x to already have 1792
# channels, matching the projection below -- confirm against the caller.
InceptionResNetC {bnTimeConst} = {
    apply(x) = {
        # Branch 1: a single 1x1 convolution with 192 output channels
        branch1x1 = ConvBNReLULayer{192, (1:1), (1:1), true, bnTimeConst}(x)

        # Branch 2: 1x1 convolution, then (1:3) and (3:1) convolutions, 192 channels each
        branch3x3 = Sequential( 
            ConvBNReLULayer{192, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{192, (1:3), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{192, (3:1), (1:1), true, bnTimeConst}
        ) (x)

        # Join both branches along axis 3 (presumably the channel axis -- confirm),
        # then project to 1792 channels with a 1x1 ConvBNLayer, i.e. without ReLU
        concat = Splice((branch1x1:branch3x3), axis=3)
        residual = ConvBNLayer{1792, (1:1), (1:1), true, bnTimeConst}(concat)

        # Residual connection: add the block input, then apply the activation
        sum = Plus(residual, x)
        out = ReLU(sum)
    }.out
}.apply

#
# Figure 7 from https://arxiv.org/pdf/1602.07261v2.pdf
#
# Reduction block with three parallel branches that all read the same input x
# and whose outputs are spliced together along axis 3. There is no residual
# connection in this block.
#   k - output channels of the 1x1 convolution in the double-3x3 branch
#   l - output channels of the stride-1 3x3 convolution in the double-3x3 branch
#   m - output channels of the stride-2 3x3 convolution in the double-3x3 branch
#   n - output channels of the stride-2 3x3 convolution in the single-3x3 branch
#   bnTimeConst - batch normalization time constant handed to every ConvBNReLULayer
# The last layer of each branch uses stride (2:2) with padding disabled
# (presumably so all branches produce the same reduced spatial size for the
# Splice -- confirm).
ReductionA {k, l, m, n, bnTimeConst} = {
    apply(x) = {
        # Branch 1: a single 3x3 convolution, stride 2, no padding, n channels
        branch3x3 = Sequential( 
            ConvBNReLULayer{n, (3:3), (2:2), false, bnTimeConst}
        ) (x)

        # Branch 2: 1x1 (k) -> 3x3 (l), both stride 1 and padded,
        # then 3x3 (m) with stride 2 and no padding
        branch3x3dbl = Sequential(
            ConvBNReLULayer{k, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{l, (3:3), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{m, (3:3), (2:2), false, bnTimeConst}
        ) (x)

        # Branch 3: 3x3 max pooling, stride 2, no padding, applied directly to x
        branch_pool = MaxPoolingLayer{(3:3), stride = (2:2), pad = false}(x)

        # Join the three branches along axis 3 (presumably the channel axis -- confirm)
        out = Splice((branch3x3:branch3x3dbl:branch_pool), axis=3)
    }.out
}.apply

#
# Figure 12 from https://arxiv.org/pdf/1602.07261v2.pdf
#
# Reduction block with four parallel branches that all read the same input x
# and whose outputs are spliced together along axis 3. There is no residual
# connection in this block.
#   bnTimeConst - batch normalization time constant handed to every ConvBNReLULayer
# The last layer of each branch uses stride (2:2) with padding disabled
# (presumably so all branches produce the same reduced spatial size for the
# Splice -- confirm).
ReductionB {bnTimeConst} = {
    apply(x) = {
        # Branch 1: 1x1 convolution (256 channels), then 3x3 stride-2 unpadded
        # convolution (384 channels)
        branch3x3_1 = Sequential( 
            ConvBNReLULayer{256, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{384, (3:3), (2:2), false, bnTimeConst}
        ) (x)

        # Branch 2: same shape as branch 1, but the stride-2 convolution has 256 channels
        branch3x3_2 = Sequential( 
            ConvBNReLULayer{256, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{256, (3:3), (2:2), false, bnTimeConst}
        ) (x)

        # Branch 3: 1x1 -> 3x3 (both stride 1 and padded), then 3x3 with stride 2
        # and no padding; 256 channels each
        branch3x3dbl = Sequential(
            ConvBNReLULayer{256, (1:1), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{256, (3:3), (1:1), true, bnTimeConst} :
            ConvBNReLULayer{256, (3:3), (2:2), false, bnTimeConst}
        )(x)

        # Branch 4: 3x3 max pooling, stride 2, no padding, applied directly to x
        branch_pool = MaxPoolingLayer{(3:3), stride = (2:2), pad = false}(x)

        # Join the four branches along axis 3 (presumably the channel axis -- confirm)
        out = Splice((branch3x3_1:branch3x3_2:branch3x3dbl:branch_pool), axis=3)
    }.out
}.apply